{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import scipy.io as sio                     # import scipy.io for .mat file I/O \n",
    "import numpy as np                         # import numpy\n",
    "import matplotlib.pyplot as plt            # import matplotlib.pyplot for figure plotting\n",
    "import function_wmmse_powercontrol as wf\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Gaussian IC Case: K=20, Total Samples: 2000, Total Iterations: 50\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# Experiment configuration for the K-user Gaussian interference channel (IC) case.\n",
    "K = 20                  # number of users\n",
    "num_H = 2000           # number of training samples\n",
    "num_test = 500            # number of testing  samples\n",
    "# NOTE(review): training_epochs is only printed below; no loop visible in this notebook reads it.\n",
    "training_epochs = 50      # number of training epochs\n",
    "trainseed = 0              # set random seed for training set\n",
    "testseed = 7               # set random seed for test set\n",
    "print('Gaussian IC Case: K=%d, Total Samples: %d, Total Iterations: %d\\n'%(K, num_H, training_epochs))\n",
    "var_db = 0                 # dB value converted to the linear-scale `var` on the next line\n",
    "var = 1/10**(var_db/10)    # equals 10**(-var_db/10); passed as var_noise to the data generator and added in the training loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "def generate_wGaussian(K, num_H, var_noise=1, Pmin=0, seed=2017):\n",
    "    \"\"\"Draw random channel matrices and run the WMMSE solver on each of them.\n",
    "\n",
    "    For every sample a K x K complex matrix is drawn whose real and imaginary\n",
    "    parts are independent standard normals scaled by 1/sqrt(2); its\n",
    "    element-wise magnitude is used as the channel matrix H.\n",
    "\n",
    "    Args:\n",
    "        K: number of users (H is K x K).\n",
    "        num_H: number of samples to generate.\n",
    "        var_noise: noise variance forwarded to wf.WMMSE_sum_rate2.\n",
    "        Pmin: unused; kept so that existing callers keep working.\n",
    "        seed: value passed to np.random.seed before any sampling.\n",
    "\n",
    "    Returns:\n",
    "        X: (K**2, num_H) array; column i is H of sample i flattened in\n",
    "            column-major (order='F') order.\n",
    "        Y: (K, num_H) array; column i is the value returned by\n",
    "            wf.WMMSE_sum_rate2 for sample i.\n",
    "        alpha: (num_H, K) array of ones; row i is passed to the solver as its\n",
    "            second argument.\n",
    "        total_time: seconds spent inside wf.WMMSE_sum_rate2, summed over all\n",
    "            samples.\n",
    "    \"\"\"\n",
    "    print('Generate Data ... (seed = %d)' % seed)\n",
    "    np.random.seed(seed)\n",
    "    Pmax = 1\n",
    "    Pini = Pmax*np.ones(K)\n",
    "    alpha = np.ones((num_H,K))\n",
    "    X = np.zeros((K**2,num_H))\n",
    "    Y = np.zeros((K,num_H))\n",
    "    total_time = 0.0\n",
    "    for loop in range(num_H):\n",
    "        CH = 1/np.sqrt(2)*(np.random.randn(K,K)+1j*np.random.randn(K,K))\n",
    "        X[:,loop] = np.reshape(abs(CH), (K**2,), order=\"F\")\n",
    "        # Rebuild H from the stored column (as before) so the solver receives the same array as it always did.\n",
    "        H = np.reshape(X[:,loop], (K,K), order=\"F\")\n",
    "        # perf_counter is monotonic, so the measured interval cannot be distorted by wall-clock adjustments.\n",
    "        solver_start = time.perf_counter()\n",
    "        Y[:,loop] = wf.WMMSE_sum_rate2(Pini, alpha[loop,:], H, Pmax, var_noise)\n",
    "        # Accumulate the difference directly instead of adding to a large absolute timestamp first.\n",
    "        total_time += time.perf_counter() - solver_start\n",
    "\n",
    "    return X, Y, alpha, total_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Generate Data ... (seed = 0)\n",
      "Generate Data ... (seed = 7)\n",
      "(2000, 400) (2000, 20)\n"
     ]
    }
   ],
   "source": [
    "# Generate the training and test sets: channels, WMMSE outputs, weights and accumulated solver time.\n",
    "Xtrain, Ytrain, Atrain, wtime = generate_wGaussian(K, num_H, seed=trainseed, var_noise = var)\n",
    "X, Y, A, wmmsetime = generate_wGaussian(K, num_test, seed=testseed, var_noise = var)\n",
    "# The generator stores samples along the last axis; transpose so that axis 0 indexes samples.\n",
    "Xtrain = Xtrain.transpose()\n",
    "X = X.transpose()\n",
    "Ytrain = Ytrain.transpose()\n",
    "Y = Y.transpose()\n",
    "print(Xtrain.shape,Ytrain.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn every flattened K*K row back into a (K, K) matrix.\n",
    "# NOTE(review): generate_wGaussian flattened H with order='F', while this reshape uses the default C order,\n",
    "# so each (K, K) slice here is the transpose of the H that was handed to the WMMSE solver -- confirm this is intended.\n",
    "Xtrain = Xtrain.reshape((-1,K,K))\n",
    "X = X.reshape((-1,K,K))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_features(H,n,L,alpha):\n",
    "    \"\"\"Split each channel matrix into direct-link and cross-link feature arrays.\n",
    "\n",
    "    Vectorised with NumPy instead of looping over every sample and user.\n",
    "\n",
    "    Args:\n",
    "        H: array-like of shape (m, L, L) with m >= n; only the first n\n",
    "            matrices are used. The input is never modified.\n",
    "        n: number of samples to process.\n",
    "        L: number of users (side length of each matrix).\n",
    "        alpha: returned unchanged as the last element of the result.\n",
    "\n",
    "    Returns:\n",
    "        A tuple (direct_H, inter_to, inter_from, other_H, alpha) where\n",
    "        direct_H[i, j]      = H[i, j, j]                  -> shape (n, L)\n",
    "        inter_to[i, j, :]   = column j of H[i]            -> shape (n, L, L)\n",
    "        inter_from[i, j, :] = row j of H[i]               -> shape (n, L, L)\n",
    "        other_H[i, j, :]    = diagonal of H[i]            -> shape (n, L, L)\n",
    "        and entry [i, j, j] of the three (n, L, L) arrays is set to 0.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if H is not 3-D, has fewer than n matrices, or its\n",
    "            matrices are not L x L.\n",
    "    \"\"\"\n",
    "    H = np.asarray(H, dtype=float)\n",
    "    if H.ndim != 3 or H.shape[0] < n or H.shape[1:] != (L, L):\n",
    "        raise ValueError('H must have shape (>= %d, %d, %d), got %s' % (n, L, L, H.shape))\n",
    "    chans = H[:n]\n",
    "    diag_idx = np.arange(L)\n",
    "    # np.diagonal returns a read-only view, so copy to hand back an independent, writable array.\n",
    "    direct_H = np.diagonal(chans, axis1=1, axis2=2).copy()\n",
    "    inter_to = np.transpose(chans, (0, 2, 1)).copy()\n",
    "    inter_from = chans.copy()\n",
    "    other_H = np.repeat(direct_H[:, np.newaxis, :], L, axis=1)\n",
    "    # Zero the self-link entry [i, j, j] by assignment (not multiplication) so NaN/inf on the diagonal still become 0.\n",
    "    for blanked in (inter_to, inter_from, other_H):\n",
    "        blanked[:, diag_idx, diag_idx] = 0\n",
    "    return direct_H, inter_to, inter_from, other_H, alpha\n",
    "\n",
    "# Feature tuple for the training channels: (direct, inter_to, inter_from, other, alpha).\n",
    "features = extract_features(Xtrain,num_H,K,Atrain)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_labels(y,n,L):\n",
    "    \"\"\"One-hot encode whether each entry of y is (numerically) zero.\n",
    "\n",
    "    Args:\n",
    "        y: indexable as y[i][j] for i < n and j < L.\n",
    "        n: number of samples.\n",
    "        L: number of users per sample.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (n, L, 2): [1, 0] where abs(y[i][j]) < 1e-4,\n",
    "        otherwise [0, 1].\n",
    "    \"\"\"\n",
    "    one_hot = np.zeros((n, L, 2))\n",
    "    for sample, user in np.ndindex(n, L):\n",
    "        is_off = abs(y[sample][user]) < 1e-4\n",
    "        one_hot[sample, user, :] = [1, 0] if is_off else [0, 1]\n",
    "    return one_hot\n",
    "# One-hot on/off labels derived from the WMMSE outputs.\n",
    "# NOTE(review): `labels` is not read by any later cell visible here, and `labels_t` is rebound when the test batch is built -- confirm these are still needed.\n",
    "labels = extract_labels(Ytrain,num_H,K)\n",
    "labels_t = extract_labels(Y,num_test,K)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trainable parameters; the same variables are reused by every unrolled round of the network.\n",
    "weights = {\n",
    "    # 1x1 convolution kernels, laid out as [height, width, in_channels, out_channels]: 5 -> 16 -> 16 -> 6.\n",
    "    'wif1': tf.Variable(tf.random_normal([1, 1, 5, 16],stddev=0.1)),\n",
    "    'wif2': tf.Variable(tf.random_normal([1, 1, 16, 16],stddev=0.1)),\n",
    "    'wif3': tf.Variable(tf.random_normal([1, 1, 16, 6],stddev=0.1)),\n",
    "    \n",
    "    # Fully connected layers: 15 -> 32 -> 16 -> 1.\n",
    "    # NOTE(review): wfc3 (and bfc3 below) omit stddev=0.1 and therefore use tf.random_normal's default stddev -- confirm intended.\n",
    "    'wfc1': tf.Variable(tf.random_normal([15, 32],stddev=0.1)),\n",
    "    'wfc2': tf.Variable(tf.random_normal([32, 16],stddev=0.1)),\n",
    "    'wfc3': tf.Variable(tf.random_normal([16, 1])),\n",
    "}\n",
    "\n",
    "# One bias vector per layer, sized to that layer's output width.\n",
    "biases = {\n",
    "    'bif1': tf.Variable(tf.random_normal([16],stddev=0.1)),\n",
    "    'bif2': tf.Variable(tf.random_normal([16],stddev=0.1)),\n",
    "    'bif3': tf.Variable(tf.random_normal([6],stddev=0.1)),\n",
    "    \n",
    "    'bfc1': tf.Variable(tf.random_normal([32],stddev=0.1)),\n",
    "    'bfc2': tf.Variable(tf.random_normal([16],stddev=0.1)),\n",
    "    'bfc3': tf.Variable(tf.random_normal([1])),\n",
    "    \n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def conv2d(x, W, b, strides=1):\n",
    "    \"\"\"Apply a SAME-padded 2-D convolution, add the bias, then ReLU.\n",
    "\n",
    "    Args:\n",
    "        x: input tensor handed to tf.nn.conv2d.\n",
    "        W: convolution kernel.\n",
    "        b: bias added to the convolution output via tf.nn.bias_add.\n",
    "        strides: stride used for both middle dimensions.\n",
    "\n",
    "    Returns:\n",
    "        The rectified, biased convolution output.\n",
    "    \"\"\"\n",
    "    stride_spec = [1, strides, strides, 1]\n",
    "    conv_out = tf.nn.conv2d(x, W, strides=stride_spec, padding='SAME')\n",
    "    return tf.nn.relu(tf.nn.bias_add(conv_out, b))\n",
    "\n",
    "def inter_conv_net(Ws,bs,inter):\n",
    "    \"\"\"Run three conv2d layers, then pool over axis 2 with both sum and max.\n",
    "\n",
    "    Args:\n",
    "        Ws: sequence whose first three entries are the conv kernels.\n",
    "        bs: sequence whose first three entries are the matching biases.\n",
    "        inter: input tensor of the first convolution.\n",
    "\n",
    "    Returns:\n",
    "        The sum-pooled and max-pooled features concatenated along axis 2\n",
    "        (sum first, max second).\n",
    "    \"\"\"\n",
    "    hidden = inter\n",
    "    for layer in range(3):\n",
    "        hidden = conv2d(hidden, Ws[layer], bs[layer])\n",
    "    # Despite the original local name (fea_mean), this pooling is a sum, not a mean.\n",
    "    pooled_sum = tf.reduce_sum(hidden,axis=2)\n",
    "    pooled_max = tf.reduce_max(hidden,axis=2)\n",
    "    return tf.concat([pooled_sum,pooled_max],axis=2)\n",
    "\n",
    "def fully_connected_net(Ws,bs,feat):\n",
    "    \"\"\"Three-layer perceptron applied along axis 2 of feat.\n",
    "\n",
    "    The first two layers use ReLU; the last layer is linear.\n",
    "\n",
    "    Args:\n",
    "        Ws: sequence whose first three entries are the weight matrices.\n",
    "        bs: sequence whose first three entries are the matching biases.\n",
    "        feat: tensor whose axis 2 is contracted with axis 0 of Ws[0].\n",
    "\n",
    "    Returns:\n",
    "        The pre-activation output of the third layer.\n",
    "    \"\"\"\n",
    "    contract = [[2], [0]]  # feature axis of the activations against the input axis of the weights\n",
    "    act = feat\n",
    "    for layer in range(2):\n",
    "        act = tf.nn.relu( tf.tensordot(act,Ws[layer],contract) + bs[layer])\n",
    "    return tf.tensordot(act,Ws[2],contract) + bs[2]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph inputs. L is the number of users per sample (same as K).\n",
    "L = K\n",
    "# (batch, L, L, 1) maps: Xinterf, Xintert, Xdiag_o.  (batch, L, 1) per-user columns: Xdiag, intensity, w_alpha, y_.\n",
    "# Htrain holds the raw (batch, L, L) channel matrices.\n",
    "Xinterf = tf.placeholder(tf.float32, [None, L, L, 1])\n",
    "Htrain = tf.placeholder(tf.float32, [None, L, L])\n",
    "Xintert = tf.placeholder(tf.float32, [None, L, L, 1])\n",
    "Xdiag = tf.placeholder(tf.float32, [None, L, 1])\n",
    "Xdiag_o = tf.placeholder(tf.float32, [None, L, L, 1])\n",
    "intensity = tf.placeholder(tf.float32, [None, L, 1])\n",
    "w_alpha = tf.placeholder(tf.float32, [None, L, 1])\n",
    "# NOTE(review): y_ is fed in the training/evaluation cells, but no graph op visible in this notebook consumes it.\n",
    "y_ = tf.placeholder(tf.float32, [None, L, 1])\n",
    "\n",
    "# Layer parameters collected in application order for inter_conv_net.\n",
    "Winterf = [weights['wif1'],weights['wif2'],weights['wif3']]\n",
    "binterf = [biases['bif1'],biases['bif2'],biases['bif3']]\n",
    "\n",
    "# Layer parameters collected in application order for fully_connected_net.\n",
    "Wfc = [weights['wfc1'],weights['wfc2'],weights['wfc3']]\n",
    "bfc = [biases['bfc1'],biases['bfc2'],biases['bfc3']]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the unrolled network and the negative weighted sum-rate loss.\n",
    "intens = intensity\n",
    "\n",
    "# Contracting the trailing size-1 axis with a (1, L) row of ones tiles each per-user value L times -> (batch, L, L).\n",
    "all_one = tf.ones((1,L),dtype=tf.float32)\n",
    "intens2 = tf.tensordot(intens, all_one, [[2], [0]]) #tf.matmul(intens,all_one)\n",
    "w2 = tf.tensordot(w_alpha, all_one, [[2], [0]])\n",
    "\n",
    "# After the transpose, entry [b, i, j] holds the value of user j, i.e. every row carries the whole per-user vector.\n",
    "intens2 = tf.transpose(intens2, perm=[0, 2, 1])\n",
    "intens2 = tf.reshape(intens2, (-1,L,L,1))\n",
    "\n",
    "w2 = tf.transpose(w2, perm=[0, 2, 1])\n",
    "w2 = tf.reshape(w2, (-1,L,L,1))\n",
    "\n",
    "w_alpha2 = tf.reshape(w_alpha,(-1,L))\n",
    "\n",
    "# Stack the five maps as channels -> (batch, L, L, 5).\n",
    "xf = tf.concat((Xinterf,Xintert,Xdiag_o,intens2,w2), axis=3)\n",
    "\n",
    "# Five unrolled rounds sharing the same conv/FC variables; each round's sigmoid output is the next round's intensity.\n",
    "for ii in range(5):\n",
    "    fea1 = inter_conv_net(Winterf,binterf,xf)\n",
    "    fea = tf.concat([Xdiag, intens, fea1, w_alpha], axis=2)#\n",
    "    out = fully_connected_net(Wfc,bfc,fea)\n",
    "    pred = tf.nn.sigmoid(out) \n",
    "    intens = pred\n",
    "    intens2 = tf.tensordot(intens, all_one, [[2], [0]])\n",
    "    intens2 = tf.transpose(intens2, perm=[0, 2, 1])\n",
    "    intens2 = tf.reshape(intens2, (-1,L,L,1))\n",
    "    # NOTE(review): the channel order here (Xinterf, intens2, Xintert, Xdiag_o, w2) differs from the initial xf\n",
    "    # built before the loop (Xinterf, Xintert, Xdiag_o, intens2, w2) although the same conv weights are applied -- confirm intended.\n",
    "    xf = tf.concat((Xinterf,intens2,Xintert,Xdiag_o,w2), axis=3)\n",
    "    \n",
    "# Loss: with G = Htrain**2 and p = pred, fr_i = G_ii * p_i and ag_i = sum_j G_ij * p_j + var - fr_i.\n",
    "pred = tf.reshape(pred,(-1,K))\n",
    "H = tf.math.square(Htrain)\n",
    "H_diag = tf.matrix_diag_part(H)\n",
    "fr = H_diag * pred\n",
    "pred2 = tf.reshape(pred,[-1,K,1])\n",
    "ag = tf.reshape(tf.matmul(H,pred2),[-1,K]) + var - fr\n",
    "# Per-sample objective: sum_i w_alpha_i * log(1 + fr_i / ag_i) (natural log).\n",
    "obj = tf.reduce_sum(w_alpha2*tf.log(1+fr/ag),axis=1 )\n",
    "# Minimising the negative batch mean maximises the objective; the y_ placeholder is not used in this cell.\n",
    "cost2 = -tf.reduce_mean(obj,axis=0)\n",
    "\n",
    "train_step = tf.train.AdamOptimizer(1e-3).minimize(cost2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(500, 20, 1)\n"
     ]
    }
   ],
   "source": [
    "def get_batch(H, features, labels, num_train, batch_size, L, is_test=False):\n",
    "    \"\"\"Select a batch of samples and reshape every piece for the graph inputs.\n",
    "\n",
    "    Args:\n",
    "        H: channel matrices indexed as H[idx, :, :].\n",
    "        features: 5-item sequence; items 0 and 4 are indexed as [idx, :],\n",
    "            items 1-3 as [idx, :, :].\n",
    "        labels: array indexed as labels[idx, :].\n",
    "        num_train: exclusive upper bound of the random sample indices.\n",
    "        batch_size: number of samples in the batch.\n",
    "        L: number of users per sample.\n",
    "        is_test: when true take samples 0..batch_size-1 in order, otherwise\n",
    "            draw indices with np.random.randint (with replacement).\n",
    "\n",
    "    Returns:\n",
    "        Tuple of (features[0], features[1], features[2], features[3],\n",
    "        labels, features[4], H) restricted to the batch; the two-index items\n",
    "        are reshaped to (batch_size, L, 1), features[1..3] to\n",
    "        (batch_size, L, L, 1), and H is left as indexed.\n",
    "    \"\"\"\n",
    "    if is_test:\n",
    "        idx = np.array(range(batch_size))\n",
    "    else:\n",
    "        idx = np.random.randint(num_train,size=batch_size)\n",
    "    vec_shape = (batch_size, L, 1)\n",
    "    map_shape = (batch_size, L, L, 1)\n",
    "    channels = H[idx,:,:]\n",
    "    direct = np.reshape(features[0][idx,:], vec_shape)\n",
    "    link_maps = [np.reshape(features[k][idx,:,:], map_shape) for k in (1, 2, 3)]\n",
    "    weight_col = np.reshape(features[4][idx,:], vec_shape)\n",
    "    targets = np.reshape(labels[idx,:], vec_shape)\n",
    "\n",
    "    return (direct, link_maps[0], link_maps[1], link_maps[2], targets, weight_col, channels)\n",
    "def IC_sum_rate(H, alpha, p, var_noise):\n",
    "    \"\"\"Weighted sum of log(1 + SINR) for one channel matrix (natural log).\n",
    "\n",
    "    Args:\n",
    "        H: square matrix; its element-wise square supplies the gains.\n",
    "        alpha: per-user weights multiplying each log term.\n",
    "        p: per-user power vector.\n",
    "        var_noise: noise variance added to the interference.\n",
    "\n",
    "    Returns:\n",
    "        sum_i alpha_i * log(1 + G_ii p_i / (sum_j G_ij p_j + var_noise - G_ii p_i))\n",
    "        with G = H**2.\n",
    "    \"\"\"\n",
    "    gains = np.square(H)\n",
    "    signal = np.diag(gains)*p\n",
    "    interference_plus_noise = np.dot(gains,p) + var_noise - signal\n",
    "    return np.sum(alpha * np.log(1+signal/interference_plus_noise) )\n",
    "def np_sum_rate(X,Y,alpha,var_noise=1):\n",
    "    \"\"\"Average IC_sum_rate over the samples along axis 0 of X.\n",
    "\n",
    "    Args:\n",
    "        X: array of shape (n, K, K); X[i] is the channel matrix of sample i.\n",
    "        Y: array indexed as Y[i, :], the power vector of sample i.\n",
    "        alpha: array indexed as alpha[i, :], the weights of sample i.\n",
    "        var_noise: noise variance forwarded to IC_sum_rate. Defaults to 1,\n",
    "            the value that used to be hard-coded here, so existing calls\n",
    "            are unaffected.\n",
    "\n",
    "    Returns:\n",
    "        The mean of IC_sum_rate over the n samples (0 when n == 0).\n",
    "    \"\"\"\n",
    "    n = X.shape[0]\n",
    "    avg = 0\n",
    "    for i in range(n):\n",
    "        avg += IC_sum_rate(X[i,:,:],alpha[i,:],Y[i,:],var_noise)/n\n",
    "    return avg\n",
    "\n",
    "\n",
    "# Build the test-set features and take the first num_test samples in order (is_test=True).\n",
    "test_fea = extract_features(X,num_test,K,A)\n",
    "# NOTE: this rebinds labels_t to Y reshaped to (num_test, K, 1), replacing the one-hot array produced by extract_labels.\n",
    "di_t, intert_t, interf_t, diag_t, labels_t, alpha_t, HHH_t = get_batch(X,test_fea,Y,num_test,num_test,K,is_test=True)\n",
    "print(labels_t.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2.4887466100001423\n",
      "-0.99315524\n",
      "-2.5160837\n",
      "-2.5602236\n",
      "-2.5619988\n"
     ]
    }
   ],
   "source": [
    "# Reference value: average weighted sum rate (natural-log units) of the WMMSE outputs on the test set.\n",
    "print(np_sum_rate(X,Y,A))\n",
    "# NOTE(review): no TensorFlow seed is set in the cells visible here, so the variable initialisation is not reproducible.\n",
    "sess = tf.InteractiveSession()\n",
    "tf.global_variables_initializer().run()\n",
    "batch_size = 50\n",
    "# total counts drawn training samples (num_H*1000); the loop below runs total//batch_size optimisation steps.\n",
    "total = num_H*1000\n",
    "for ii in range(total//batch_size):\n",
    "    \n",
    "    di, intert, interf, diag, batch_ys, alpha_tr, HHH = get_batch(Xtrain,features,Ytrain, num_H, batch_size,K)\n",
    "    # Every link starts the unrolled network from an all-ones intensity.\n",
    "    start_state = np.ones((batch_size,L,1))\n",
    "    sess.run(train_step, feed_dict={Htrain:HHH, Xinterf: interf, Xintert: intert, Xdiag: di, Xdiag_o: diag, y_: batch_ys, intensity: start_state, w_alpha: alpha_tr})\n",
    "    # Every 10000 steps, print cost2 (the negative mean objective) evaluated on the whole test set.\n",
    "    if(ii%10000 == 0):\n",
    "        start_state = np.ones((num_test,L,1))\n",
    "        cost_value = sess.run(cost2, feed_dict={Htrain:HHH_t, Xinterf: interf_t, Xintert: intert_t, Xdiag: di_t, Xdiag_o: diag_t, y_: labels_t, intensity: start_state, w_alpha: alpha_t})\n",
    "        print(cost_value) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "sum rate for IGCNet 3.6597809502824763\n",
      "sum rate for WMMSE 3.590502392276424\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the trained network on the test set, starting every link from an all-ones intensity.\n",
    "start_state = np.ones((num_test,L,1))\n",
    "# cost_value holds the network's final per-user output (pred2), shape (num_test, K, 1).\n",
    "cost_value = sess.run(pred2, feed_dict={Htrain:HHH_t, Xinterf: interf_t, Xintert: intert_t, Xdiag: di_t, Xdiag_o: diag_t, y_: labels_t, intensity: start_state, w_alpha: alpha_t}) \n",
    "# Reshape with num_test instead of a hard-coded 500 so this cell follows the configured test-set size.\n",
    "# Note: this rebinds `pred` from the graph tensor to a NumPy array.\n",
    "pred = np.reshape(cost_value,(num_test,K))\n",
    "\n",
    "# np_sum_rate works in natural-log units; multiplying by log2(e) converts the averages to bits.\n",
    "print('sum rate for IGCNet', np_sum_rate(X,pred,A)*np.log2(np.exp(1)))\n",
    "print('sum rate for WMMSE', np_sum_rate(X,Y,A)*np.log2(np.exp(1)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
